import numpy as np
import argparse
import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
import datetime
import matplotlib.pyplot as plt
from utils import data_loader, model_evaluation, set_random_seed

def get_parser():
    """Build the command-line parser carrying the experiment's default settings.

    ``--Model_name`` selects the estimator built by ``model_select``.
    ``--seed`` is the random seed.
    The dataset options are presumably consumed by ``utils.data_loader``; verify against that helper.
    ``--input_size`` and ``--class_num`` are likewise not read in this file.

    Returns:
        argparse.ArgumentParser: parser whose ``parse_args()`` yields the run configuration.
    """
    cli = argparse.ArgumentParser(description='Models config parser')
    # --seed is the only option that carries help text, so it is added on its own.
    cli.add_argument('--seed', type=int, default=3001, help="random seed")
    # (flag, type, default) for the remaining options, registered in this order.
    option_specs = (
        ('--input_size', int, 24),
        ('--class_num', int, 4),
        ('--data_dir', str, 'Datasets'),
        ('--train_set', str, 'Datasets_C2.pkl'),
        ('--test_set', str, 'Datasets_C1.pkl'),
        ('--Model_name', str, 'RFmodel'),
    )
    for flag, kind, fallback in option_specs:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli

def model_select(args, X_train, y_train, X_test):
    """Build the classifier named by ``args.Model_name``, fit it, and predict on ``X_test``.

    Supported names:
        'SVCmodel'  -- linear-kernel SVC with C=0.5
        'DTmodel'   -- DecisionTreeClassifier with default settings
        'RFmodel'   -- RandomForestClassifier with 200 trees
        'KNNmodel'  -- KNeighborsClassifier using the kd-tree algorithm
        'MLPmodel'  -- MLPClassifier (lbfgs solver, alpha=1e-5, two hidden layers of 100 units)

    Args:
        args: Parsed CLI namespace; only ``args.Model_name`` is read here.
        X_train: Training features, passed straight to ``fit``.
        y_train: Training labels. They are flattened with ``np.ravel`` before fitting,
            so a column vector is accepted.
        X_test: Features to predict on.

    Returns:
        tuple: ``(fitted_model, y_pred)``, where ``y_pred = fitted_model.predict(X_test)``.

    Raises:
        ValueError: If ``args.Model_name`` is not one of the supported names.
            Previously this surfaced as an ``UnboundLocalError``.
    """
    supported = ('SVCmodel', 'DTmodel', 'RFmodel', 'KNNmodel', 'MLPmodel')
    name = args.Model_name

    # Each estimator is imported from its own submodule. A bare ``import sklearn``
    # does not guarantee that ``sklearn.svm`` / ``sklearn.neighbors`` are loaded.
    if name == 'SVCmodel':
        from sklearn.svm import SVC
        proposed_model = SVC(C=0.5, kernel='linear')
    elif name == 'DTmodel':
        from sklearn.tree import DecisionTreeClassifier
        proposed_model = DecisionTreeClassifier()
    elif name == 'RFmodel':
        from sklearn.ensemble import RandomForestClassifier
        proposed_model = RandomForestClassifier(n_estimators=200)
    elif name == 'KNNmodel':
        from sklearn.neighbors import KNeighborsClassifier
        proposed_model = KNeighborsClassifier(algorithm='kd_tree')
    elif name == 'MLPmodel':
        from sklearn.neural_network import MLPClassifier
        proposed_model = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(100, 100))
    else:
        raise ValueError(
            f"Unknown Model_name {name!r}; expected one of {', '.join(supported)}"
        )

    # sklearn expects a 1-D label array; ravel flattens e.g. an (n, 1) column.
    proposed_model.fit(X_train, np.ravel(y_train))
    y_pred = proposed_model.predict(X_test)
    return proposed_model, y_pred


if __name__ == "__main__":
    # Read the run configuration from the command line, then seed the RNGs
    # before any data loading or model fitting.
    args = get_parser().parse_args()
    set_random_seed(args.seed)

    # Load the train/test splits, then fit the selected model and predict.
    X_train, y_train, X_test, y_test = data_loader(args)
    trained_model, y_pred = model_select(args, X_train, y_train, X_test)

    # Local-time stamp (YYYYmmddHHMMSS) handed to the evaluation helper.
    # Presumably used to tag output files; confirm in utils.model_evaluation.
    run_stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    model_evaluation(y_pred, y_test, args, run_stamp)
